{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn import datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "digits = datasets.load_digits()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['data', 'target', 'target_names', 'images', 'DESCR'])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "digits.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ".. _digits_dataset:\n",
      "\n",
      "Optical recognition of handwritten digits dataset\n",
      "--------------------------------------------------\n",
      "\n",
      "**Data Set Characteristics:**\n",
      "\n",
      "    :Number of Instances: 5620\n",
      "    :Number of Attributes: 64\n",
      "    :Attribute Information: 8x8 image of integer pixels in the range 0..16.\n",
      "    :Missing Attribute Values: None\n",
      "    :Creator: E. Alpaydin (alpaydin '@' boun.edu.tr)\n",
      "    :Date: July; 1998\n",
      "\n",
      "This is a copy of the test set of the UCI ML hand-written digits datasets\n",
      "https://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits\n",
      "\n",
      "The data set contains images of hand-written digits: 10 classes where\n",
      "each class refers to a digit.\n",
      "\n",
      "Preprocessing programs made available by NIST were used to extract\n",
      "normalized bitmaps of handwritten digits from a preprinted form. From a\n",
      "total of 43 people, 30 contributed to the training set and different 13\n",
      "to the test set. 32x32 bitmaps are divided into nonoverlapping blocks of\n",
      "4x4 and the number of on pixels are counted in each block. This generates\n",
      "an input matrix of 8x8 where each element is an integer in the range\n",
      "0..16. This reduces dimensionality and gives invariance to small\n",
      "distortions.\n",
      "\n",
      "For info on NIST preprocessing routines, see M. D. Garris, J. L. Blue, G.\n",
      "T. Candela, D. L. Dimmick, J. Geist, P. J. Grother, S. A. Janet, and C.\n",
      "L. Wilson, NIST Form-Based Handprint Recognition System, NISTIR 5469,\n",
      "1994.\n",
      "\n",
      ".. topic:: References\n",
      "\n",
      "  - C. Kaynak (1995) Methods of Combining Multiple Classifiers and Their\n",
      "    Applications to Handwritten Digit Recognition, MSc Thesis, Institute of\n",
      "    Graduate Studies in Science and Engineering, Bogazici University.\n",
      "  - E. Alpaydin, C. Kaynak (1998) Cascading Classifiers, Kybernetika.\n",
      "  - Ken Tang and Ponnuthurai N. Suganthan and Xi Yao and A. Kai Qin.\n",
      "    Linear dimensionalityreduction using relevance weighted LDA. School of\n",
      "    Electrical and Electronic Engineering Nanyang Technological University.\n",
      "    2005.\n",
      "  - Claudio Gentile. A New Approximate Maximal Margin Classification\n",
      "    Algorithm. NIPS. 2000.\n"
     ]
    }
   ],
   "source": [
    "print(digits.DESCR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1797, 64)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = digits.data\n",
    "X.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1797,)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y = digits.target\n",
    "y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "digits.target_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,\n",
       "       2, 3, 4, 5, 6, 7, 8, 9, 0, 9, 5, 5, 6, 5, 0, 9, 8, 9, 8, 4, 1, 7,\n",
       "       7, 3, 5, 1, 0, 0, 2, 2, 7, 8, 2, 0, 1, 2, 6, 3, 3, 7, 3, 3, 4, 6,\n",
       "       6, 6, 4, 9, 1, 5, 0, 9, 5, 2, 8, 2, 0, 0, 1, 7, 6, 3, 2, 1, 7, 4,\n",
       "       6, 3, 1, 3, 9, 1, 7, 6, 8, 4, 3, 1])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y[:100]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.,  0.,  5., 13.,  9.,  1.,  0.,  0.,  0.,  0., 13., 15., 10.,\n",
       "        15.,  5.,  0.,  0.,  3., 15.,  2.,  0., 11.,  8.,  0.,  0.,  4.,\n",
       "        12.,  0.,  0.,  8.,  8.,  0.,  0.,  5.,  8.,  0.,  0.,  9.,  8.,\n",
       "         0.,  0.,  4., 11.,  0.,  1., 12.,  7.,  0.,  0.,  2., 14.,  5.,\n",
       "        10., 12.,  0.,  0.,  0.,  0.,  6., 13., 10.,  0.,  0.,  0.],\n",
       "       [ 0.,  0.,  0., 12., 13.,  5.,  0.,  0.,  0.,  0.,  0., 11., 16.,\n",
       "         9.,  0.,  0.,  0.,  0.,  3., 15., 16.,  6.,  0.,  0.,  0.,  7.,\n",
       "        15., 16., 16.,  2.,  0.,  0.,  0.,  0.,  1., 16., 16.,  3.,  0.,\n",
       "         0.,  0.,  0.,  1., 16., 16.,  6.,  0.,  0.,  0.,  0.,  1., 16.,\n",
       "        16.,  6.,  0.,  0.,  0.,  0.,  0., 11., 16., 10.,  0.,  0.],\n",
       "       [ 0.,  0.,  0.,  4., 15., 12.,  0.,  0.,  0.,  0.,  3., 16., 15.,\n",
       "        14.,  0.,  0.,  0.,  0.,  8., 13.,  8., 16.,  0.,  0.,  0.,  0.,\n",
       "         1.,  6., 15., 11.,  0.,  0.,  0.,  1.,  8., 13., 15.,  1.,  0.,\n",
       "         0.,  0.,  9., 16., 16.,  5.,  0.,  0.,  0.,  0.,  3., 13., 16.,\n",
       "        16., 11.,  5.,  0.,  0.,  0.,  0.,  3., 11., 16.,  9.,  0.],\n",
       "       [ 0.,  0.,  7., 15., 13.,  1.,  0.,  0.,  0.,  8., 13.,  6., 15.,\n",
       "         4.,  0.,  0.,  0.,  2.,  1., 13., 13.,  0.,  0.,  0.,  0.,  0.,\n",
       "         2., 15., 11.,  1.,  0.,  0.,  0.,  0.,  0.,  1., 12., 12.,  1.,\n",
       "         0.,  0.,  0.,  0.,  0.,  1., 10.,  8.,  0.,  0.,  0.,  8.,  4.,\n",
       "         5., 14.,  9.,  0.,  0.,  0.,  7., 13., 13.,  9.,  0.,  0.],\n",
       "       [ 0.,  0.,  0.,  1., 11.,  0.,  0.,  0.,  0.,  0.,  0.,  7.,  8.,\n",
       "         0.,  0.,  0.,  0.,  0.,  1., 13.,  6.,  2.,  2.,  0.,  0.,  0.,\n",
       "         7., 15.,  0.,  9.,  8.,  0.,  0.,  5., 16., 10.,  0., 16.,  6.,\n",
       "         0.,  0.,  4., 15., 16., 13., 16.,  1.,  0.,  0.,  0.,  0.,  3.,\n",
       "        15., 10.,  0.,  0.,  0.,  0.,  0.,  2., 16.,  4.,  0.,  0.],\n",
       "       [ 0.,  0., 12., 10.,  0.,  0.,  0.,  0.,  0.,  0., 14., 16., 16.,\n",
       "        14.,  0.,  0.,  0.,  0., 13., 16., 15., 10.,  1.,  0.,  0.,  0.,\n",
       "        11., 16., 16.,  7.,  0.,  0.,  0.,  0.,  0.,  4.,  7., 16.,  7.,\n",
       "         0.,  0.,  0.,  0.,  0.,  4., 16.,  9.,  0.,  0.,  0.,  5.,  4.,\n",
       "        12., 16.,  4.,  0.,  0.,  0.,  9., 16., 16., 10.,  0.,  0.],\n",
       "       [ 0.,  0.,  0., 12., 13.,  0.,  0.,  0.,  0.,  0.,  5., 16.,  8.,\n",
       "         0.,  0.,  0.,  0.,  0., 13., 16.,  3.,  0.,  0.,  0.,  0.,  0.,\n",
       "        14., 13.,  0.,  0.,  0.,  0.,  0.,  0., 15., 12.,  7.,  2.,  0.,\n",
       "         0.,  0.,  0., 13., 16., 13., 16.,  3.,  0.,  0.,  0.,  7., 16.,\n",
       "        11., 15.,  8.,  0.,  0.,  0.,  1.,  9., 15., 11.,  3.,  0.],\n",
       "       [ 0.,  0.,  7.,  8., 13., 16., 15.,  1.,  0.,  0.,  7.,  7.,  4.,\n",
       "        11., 12.,  0.,  0.,  0.,  0.,  0.,  8., 13.,  1.,  0.,  0.,  4.,\n",
       "         8.,  8., 15., 15.,  6.,  0.,  0.,  2., 11., 15., 15.,  4.,  0.,\n",
       "         0.,  0.,  0.,  0., 16.,  5.,  0.,  0.,  0.,  0.,  0.,  9., 15.,\n",
       "         1.,  0.,  0.,  0.,  0.,  0., 13.,  5.,  0.,  0.,  0.,  0.],\n",
       "       [ 0.,  0.,  9., 14.,  8.,  1.,  0.,  0.,  0.,  0., 12., 14., 14.,\n",
       "        12.,  0.,  0.,  0.,  0.,  9., 10.,  0., 15.,  4.,  0.,  0.,  0.,\n",
       "         3., 16., 12., 14.,  2.,  0.,  0.,  0.,  4., 16., 16.,  2.,  0.,\n",
       "         0.,  0.,  3., 16.,  8., 10., 13.,  2.,  0.,  0.,  1., 15.,  1.,\n",
       "         3., 16.,  8.,  0.,  0.,  0., 11., 16., 15., 11.,  1.,  0.],\n",
       "       [ 0.,  0., 11., 12.,  0.,  0.,  0.,  0.,  0.,  2., 16., 16., 16.,\n",
       "        13.,  0.,  0.,  0.,  3., 16., 12., 10., 14.,  0.,  0.,  0.,  1.,\n",
       "        16.,  1., 12., 15.,  0.,  0.,  0.,  0., 13., 16.,  9., 15.,  2.,\n",
       "         0.,  0.,  0.,  0.,  3.,  0.,  9., 11.,  0.,  0.,  0.,  0.,  0.,\n",
       "         9., 15.,  4.,  0.,  0.,  0.,  9., 12., 13.,  3.,  0.,  0.]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "some_digit = X[666]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y[666]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAK4UlEQVR4nO3d3Ytc9R3H8c+nq9LGp4UmFElCR0UCUuhGhoAEjIltiVVMLnqRgGJCwZsqSguivZH+A5JeFEGiG8FEaaMSEasVdG2F1rqJG2tcLWnYkq3aJJTFh0JD9NuLnUC0a/fMzHnaL+8XLO7DsL/vkH17ZmbPnp8jQgDy+FrTAwAoF1EDyRA1kAxRA8kQNZDMeVV80+XLl0en06niWzdqbm6u1vVmZmZqW2tkZKS2ta688sra1lq2bFlta9VpZmZGp06d8kJfqyTqTqejycnJKr51ow4cOFDrerfffntta42Ojta21r59+2pba2xsrLa16tTtdr/yazz8BpIhaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSKRS17c2237N91PZ9VQ8FYHCLRm17RNKvJN0o6WpJ221fXfVgAAZT5Ei9TtLRiDgWEaclPSlpS7VjARhUkahXSjp+zsezvc99ge07bE/anjx58mRZ8wHoU5GoF/rzrv+5WmFEPBwR3YjorlixYvjJAAykSNSzklaf8/EqSe9XMw6AYRWJ+g1JV9m+3PYFkrZJerbasQAMatGLJETEGdt3SnpR0oikRyPiSOWTARhIoSufRMTzkp6veBYAJeCMMiAZogaSIWogGaIGkiFqIBmiBpIhaiCZSnboyOqBBx5oeoTKbN26tba1rr/++trWmpqaqm0taX53mqZxpAaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIJkiO3Q8avuE7bfrGAjAcIocqfdI2lzxHABKsmjUEfF7Sf+qYRYAJSjtOTXb7gDtUFrUbLsDtAOvfgPJEDWQTJFfaT0h6Y+S1tietf3j6scCMKgie2ltr2MQAOXg4TeQDFEDyRA1kAxRA8kQNZAMUQPJEDWQzJLfdmdiYqK2tQ4fPlzbWpK0YcOG2tbatWtXbWvNzc3VtladPx+StGPHjlrXWwhHaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogGaIGkilyjbLVtl+xPW37iO276xgMwGCKnPt9RtLPIuKQ7YslHbT9UkS8U/FsAAZQZNudDyLiUO/9jyVNS1pZ9WAABtPXc2rbHUlrJb2+wNfYdgdogcJR275I0lOS7omIj778dbbdAdqhUNS2z9d80Hsj4ulqRwIwjCKvflvSI5KmI+LB6kcCMIwiR+r1km6TtMn2VO/thxXPBWBARbbdeU2Sa5gFQAk4owxIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZNhLq8XGxsaaHqESnU6ntrXYSwvAkkfUQDJEDSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRT5MKDX7f9Z9uHe9vu/KKOwQAMpshpov+RtCkiPuldKvg127+NiD9VPBuAARS58GBI+qT34fm9t6hyKACDK3ox/xHbU5JOSHopIth2B2ipQlFHxGcRMSZplaR1tr+zwG3Ydgdogb5e/Y6IOUkTkjZXMg2AoRV59XuF7dHe+9+Q9D1J71Y9GIDBFHn1+zJJj9ke0fz/BH4dEc9VOxaAQRV59fstze9JDWAJ4IwyIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpJZ8tvujI6O1rbWpZdeWttakrRx48Za16tLnVvh1Pnz0RYcqYFkiBpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSKZw1L0L+r9pm4sOAi3Wz5H6bknTVQ0CoBxFt91ZJekmSburHQfAsIoeqXdJulfS5191A/bSAtqhyA4dN0s6EREH/9/t2EsLaIci
R+r1km6xPSPpSUmbbD9e6VQABrZo1BFxf0SsioiOpG2SXo6IWyufDMBA+D01kExflzOKiAnNb2ULoKU4UgPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJLPltd+rU6XRqXW/Lli21rXXgwIHa1nr11VdrW2t8fLy2tdqCIzWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kUOk20dyXRjyV9JulMRHSrHArA4Po593tjRJyqbBIApeDhN5BM0ahD0u9sH7R9x0I3YNsdoB2KRr0+Iq6RdKOkn9i+7ss3YNsdoB0KRR0R7/f+e0LSM5LWVTkUgMEV2SDvQtsXn31f0g8kvV31YAAGU+TV729Jesb22dvvi4gXKp0KwMAWjToijkn6bg2zACgBv9ICkiFqIBmiBpIhaiAZogaSIWogGaIGknFElP5Nu91uTE5Olv59m9Y7Aac2GzZsqG2tqamp2taqc/uiiYmJ2taSpNHR0VrW6Xa7mpycXPAHkiM1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJFIra9qjt/bbftT1t+9qqBwMwmKLb7vxS0gsR8SPbF0haVuFMAIawaNS2L5F0naQdkhQRpyWdrnYsAIMq8vD7CkknJY3bftP27t71v7+AbXeAdigS9XmSrpH0UESslfSppPu+fCO23QHaoUjUs5JmI+L13sf7NR85gBZaNOqI+FDScdtrep+6QdI7lU4FYGBFX/2+S9Le3ivfxyTtrG4kAMMoFHVETEnqVjwLgBJwRhmQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRQ9owySxsfHa11v5876Ttyrc9+uPXv21LZWXXtbtQlHaiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogmUWjtr3G9tQ5bx/ZvqeO4QD0b9HTRCPiPUljkmR7RNI/JD1T8VwABtTvw+8bJP0tIv5exTAAhtdv1NskPbHQF9h2B2iHwlH3rvl9i6TfLPR1tt0B2qGfI/WNkg5FxD+rGgbA8PqJeru+4qE3gPYoFLXtZZK+L+npascBMKyi2+78W9I3K54FQAk4owxIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZBwR5X9T+6Skfv88c7mkU6UP0w5Z7xv3qznfjogF/3KqkqgHYXsyIrpNz1GFrPeN+9VOPPwGkiFqIJk2Rf1w0wNUKOt94361UGueUwMoR5uO1ABKQNRAMq2I2vZm2+/ZPmr7vqbnKYPt1bZfsT1t+4jtu5ueqUy2R2y/afu5pmcpk+1R2/ttv9v7t7u26Zn61fhz6t4GAX/V/OWSZiW9IWl7RLzT6GBDsn2ZpMsi4pDtiyUdlLR1qd+vs2z/VFJX0iURcXPT85TF9mOS/hARu3tX0F0WEXNNz9WPNhyp10k6GhHHIuK0pCclbWl4pqFFxAcRcaj3/seSpiWtbHaqctheJekmSbubnqVMti+RdJ2kRyQpIk4vtaCldkS9UtLxcz6eVZIf/rNsdyStlfR6s5OUZpekeyV93vQgJbtC0klJ472nFrttX9j0UP1qQ9Re4HNpfs9m+yJJT0m6JyI+anqeYdm+WdKJiDjY9CwVOE/SNZIeioi1kj6VtORe42lD1LOSVp/z8SpJ7zc0S6lsn6/5oPdGRJbLK6+XdIvtGc0/Vdpk+/FmRyrNrKTZiDj7iGq/5iNfUtoQ9RuSrrJ9ee+FiW2Snm14pqHZtuafm01HxINNz1OWiLg/IlZFREfz/1YvR8StDY9Vioj4UNJx22t6n7pB0pJ7YbPQdb+rFBFnbN8p6UVJI5IejYgjDY9VhvWSbpP0F9tTvc/9PCKeb3AmLO4uSXt7B5hjknY2PE/fGv+VFoByteHhN4ASETWQDFEDyRA1kAxRA8kQNZAMUQPJ/Bf3mrPQmWo6XAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "some_digit_image = some_digit.reshape(8, 8)\n",
    "plt.imshow(some_digit_image, cmap = matplotlib.cm.binary)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def train_test_split(X, y, test_ratio=0.2, seed=None):\n",
    "    \"\"\"将X和y按照test_ratio分割成X_train，X_test，y_train，y_test\"\"\"\n",
    "    \n",
    "    assert X.shape[0] == y.shape[0], \"the size of X must be equal to the size of y\"\n",
    "    assert 0.0 <= test_ratio <= 1.0, \"test_ration must be valid\"\n",
    "    \n",
    "    if seed:\n",
    "        np.random.seed(seed)\n",
    "        \n",
    "    shuffle_indexes = np.random.permutation(len(X))\n",
    "    \n",
    "    test_size = int(len(X) * test_ratio)\n",
    "    test_indexes = shuffle_indexes[:test_size]\n",
    "    train_indexes = shuffle_indexes[test_size:]\n",
    "    \n",
    "    X_train = X[train_indexes]\n",
    "    y_train = y[train_indexes]\n",
    "\n",
    "    X_test = X[test_indexes]\n",
    "    y_test = y[test_indexes]\n",
    "    \n",
    "    return X_train, X_test, y_train, y_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed seed so the split, and every result derived from it, is reproducible on re-run.\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_ratio=0.2, seed=666)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from math import sqrt\n",
    "from collections import Counter\n",
    "\n",
    "class KNNClassifier:\n",
    "    \n",
    "    def __init__(self, k):\n",
    "        \"\"\"初始化kNN分类器\"\"\"\n",
    "        assert k >= 1, \"K must be valid!\"\n",
    "        self.k = k\n",
    "        self._X_train = None\n",
    "        self._y_train = None\n",
    "        \n",
    "    def fit(self, X_train, y_train):\n",
    "        \"\"\"根据训练数据集X_train和y_train训练kNN分类器\"\"\"\n",
    "        assert X_train.shape[0] == y_train.shape[0], \"the size of X_train must equal to the size of y_train\"\n",
    "        assert self.k <= X_train.shape[0], \"the size of X_train must be at least k\"\n",
    "    \n",
    "        self._X_train = X_train\n",
    "        self._y_train = y_train\n",
    "        return self\n",
    "    \n",
    "    def predict(self, X_predict):\n",
    "        \"\"\"给定待预测数据集X_predict, 返回表示X_predict的结果向量\"\"\"\n",
    "        assert self._X_train is not None and self._X_train is not None, \"must fit before predict\"\n",
    "        assert self._X_train.shape[1] == X_predict.shape[1], \"the feature number of X_predict must be equal to X_train\"\n",
    "        \n",
    "        y_predict = [self._predict(x) for x in X_predict]\n",
    "        return np.array(y_predict)\n",
    "    \n",
    "    def _predict(self, x):\n",
    "        \"\"\"给定单个待测数据x，返回x的预测结果\"\"\"\n",
    "        assert self._X_train.shape[1] == x.shape[0], \"the feature number of x must be equal to X_train\"\n",
    "        \n",
    "        distances = [sqrt(np.sum((x_train-x)**2)) for x_train in self._X_train]\n",
    "        nearst = np.argsort(distances)\n",
    "    \n",
    "        topK_y = [self._y_train[i] for i in nearst[:self.k]]\n",
    "        votes = Counter(topK_y)\n",
    "    \n",
    "        return votes.most_common(1)[0][0]\n",
    "    \n",
    "    def __repr__(self):\n",
    "        return \"KNN(k=%d)\" % self.k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "my_knn_clf = KNNClassifier(k=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNN(k=3)"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "my_knn_clf.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_predict = my_knn_clf.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([9, 2, 1, 7, 3, 3, 2, 4, 6, 6, 2, 2, 5, 6, 6, 5, 6, 7, 7, 9, 7, 1,\n",
       "       4, 9, 7, 8, 8, 7, 6, 3, 0, 3, 3, 6, 4, 0, 2, 0, 1, 5, 8, 4, 9, 4,\n",
       "       3, 5, 6, 0, 6, 5, 3, 4, 9, 0, 6, 8, 1, 4, 7, 5, 0, 6, 5, 1, 0, 8,\n",
       "       4, 7, 8, 6, 4, 9, 4, 9, 2, 0, 6, 0, 7, 5, 5, 4, 3, 4, 0, 7, 2, 2,\n",
       "       6, 5, 7, 2, 1, 2, 0, 6, 2, 1, 6, 0, 5, 0, 5, 5, 3, 2, 5, 3, 5, 7,\n",
       "       3, 5, 1, 8, 2, 4, 5, 3, 9, 3, 4, 5, 4, 5, 6, 4, 4, 4, 6, 5, 9, 0,\n",
       "       3, 7, 8, 7, 6, 9, 0, 0, 8, 5, 0, 8, 1, 7, 4, 0, 3, 0, 8, 4, 7, 4,\n",
       "       3, 7, 4, 8, 7, 1, 4, 0, 0, 1, 4, 2, 6, 6, 7, 7, 0, 4, 6, 6, 4, 3,\n",
       "       3, 8, 9, 7, 4, 5, 6, 6, 1, 4, 3, 7, 9, 0, 4, 2, 2, 6, 7, 6, 8, 9,\n",
       "       0, 9, 3, 1, 5, 1, 9, 3, 0, 5, 2, 4, 5, 2, 6, 4, 2, 2, 7, 7, 3, 6,\n",
       "       4, 4, 5, 0, 2, 5, 4, 5, 2, 1, 4, 1, 8, 0, 0, 9, 7, 0, 0, 9, 0, 3,\n",
       "       8, 9, 6, 9, 6, 8, 3, 9, 4, 0, 0, 4, 7, 4, 5, 7, 3, 5, 8, 7, 9, 7,\n",
       "       7, 8, 0, 2, 1, 0, 5, 0, 1, 3, 2, 4, 8, 6, 9, 2, 0, 4, 4, 6, 3, 4,\n",
       "       2, 0, 3, 3, 3, 3, 8, 1, 1, 1, 7, 0, 6, 6, 0, 4, 6, 5, 9, 3, 8, 9,\n",
       "       6, 9, 9, 8, 5, 4, 1, 4, 8, 5, 5, 1, 9, 4, 7, 3, 1, 4, 0, 5, 4, 0,\n",
       "       4, 5, 3, 0, 9, 5, 7, 3, 3, 5, 9, 4, 6, 6, 5, 7, 8, 5, 8, 8, 8, 1,\n",
       "       2, 5, 5, 0, 8, 3, 4])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9805013927576601"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum(y_predict == y_test) / len(y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def accuracy_score(y_true, y_predict):\n",
    "    '''计算y_true和y_predict之间的准确率'''\n",
    "    assert y_true.shape[0] == y_predict.shape[0], \"the size of y_true must be equal to the size of y_predict\"\n",
    "    \n",
    "    return sum(y_true == y_predict) / len(y_true)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9805013927576601"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accuracy_score(y_test, y_predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def accuracy_score(y_true, y_predict):\n",
    "    '''计算y_true和y_predict之间的准确率'''\n",
    "    assert y_true.shape[0] == y_predict.shape[0], \"the size of y_true must be equal to the size of y_predict\"\n",
    "    \n",
    "    return sum(y_true == y_predict) / len(y_true)\n",
    "\n",
    "import numpy as np\n",
    "from math import sqrt\n",
    "from collections import Counter\n",
    "\n",
    "class KNNClassifier:\n",
    "    \n",
    "    def __init__(self, k):\n",
    "        \"\"\"初始化kNN分类器\"\"\"\n",
    "        assert k >= 1, \"K must be valid!\"\n",
    "        self.k = k\n",
    "        self._X_train = None\n",
    "        self._y_train = None\n",
    "        \n",
    "    def fit(self, X_train, y_train):\n",
    "        \"\"\"根据训练数据集X_train和y_train训练kNN分类器\"\"\"\n",
    "        assert X_train.shape[0] == y_train.shape[0], \"the size of X_train must equal to the size of y_train\"\n",
    "        assert self.k <= X_train.shape[0], \"the size of X_train must be at least k\"\n",
    "    \n",
    "        self._X_train = X_train\n",
    "        self._y_train = y_train\n",
    "        return self\n",
    "    \n",
    "    def predict(self, X_predict):\n",
    "        \"\"\"给定待预测数据集X_predict, 返回表示X_predict的结果向量\"\"\"\n",
    "        assert self._X_train is not None and self._X_train is not None, \"must fit before predict\"\n",
    "        assert self._X_train.shape[1] == X_predict.shape[1], \"the feature number of X_predict must be equal to X_train\"\n",
    "        \n",
    "        y_predict = [self._predict(x) for x in X_predict]\n",
    "        return np.array(y_predict)\n",
    "    \n",
    "    def _predict(self, x):\n",
    "        \"\"\"给定单个待测数据x，返回x的预测结果\"\"\"\n",
    "        assert self._X_train.shape[1] == x.shape[0], \"the feature number of x must be equal to X_train\"\n",
    "        \n",
    "        distances = [sqrt(np.sum((x_train-x)**2)) for x_train in self._X_train]\n",
    "        nearst = np.argsort(distances)\n",
    "    \n",
    "        topK_y = [self._y_train[i] for i in nearst[:self.k]]\n",
    "        votes = Counter(topK_y)\n",
    "    \n",
    "        return votes.most_common(1)[0][0]\n",
    "    \n",
    "    def score(self, X_test, y_test):\n",
    "        \"\"\"根据测试集X_test和y_test确定当前模型的准确度\"\"\"\n",
    "        y_predict = self.predict(X_test)\n",
    "        return accuracy_score(y_test, y_predict)\n",
    "        \n",
    "    def __repr__(self):\n",
    "        return \"KNN(k=%d)\" % self.k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9805013927576601"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "my_knn_clf = KNNClassifier(k=3)\n",
    "my_knn_clf.fit(X_train, y_train)\n",
    "my_knn_clf.score(X_test, y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### scikit-learn中的accuracy_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=666)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "knn_clf = KNeighborsClassifier(n_neighbors=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n",
       "                     metric_params=None, n_jobs=None, n_neighbors=3, p=2,\n",
       "                     weights='uniform')"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "knn_clf.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_predict = knn_clf.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9888888888888889"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "accuracy_score(y_test, y_predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9888888888888889"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "knn_clf.score(X_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
